import functools
import numpy as np
import os
import json
import acl
import argparse
from tools.dataloader import save_data, load_data
from tools.constant import ACL_MEMCPY_HOST_TO_DEVICE, \
    ACL_MEM_MALLOC_HUGE_ONLY, ACL_FORMAT_ND, ACL_FORMAT_NCHW, ACL_FORMAT_NHWC, ACL_FORMAT_NC1HWC0, \
    ACL_FORMAT_FRACTAL_Z, ACL_FORMAT_FRACTAL_NZ, acl_dtype, ACL_ERROR_CODE, ACL_MEMCPY_DEVICE_TO_HOST, \
    ACL_FORMAT_HWCN, ACL_FORMAT_NCDHW, ACL_FORMAT_NDHWC, ACL_FORMAT_NDC1HWC0

# Engine id handed to acl.op.create_kernel; presumably selects the AI Core
# engine — confirm against the pyACL create_kernel documentation.
ACL_ENGINE = 1

# Dispatch table: attribute "type" string (as it appears in the attrs spec
# passed to AclOp.run) -> matching pyACL attribute setter.
attr_funcs = {"bool": acl.op.set_attr_bool, "int": acl.op.set_attr_int, "float": acl.op.set_attr_float,
              "string": acl.op.set_attr_string, "list_bool": acl.op.set_attr_list_bool,
              "list_int": acl.op.set_attr_list_int, "list_float": acl.op.set_attr_list_float,
              "list_string": acl.op.set_attr_list_string, "list_list_int": acl.op.set_attr_list_list_int}

# Maps a tensor "format" string from the inputs/outputs info dicts to the
# ACL format enum imported from tools.constant.
format_dict = {"ND": ACL_FORMAT_ND, "NCHW": ACL_FORMAT_NCHW, "NHWC": ACL_FORMAT_NHWC, "NC1HWC0": ACL_FORMAT_NC1HWC0,
               "FRACTAL_Z": ACL_FORMAT_FRACTAL_Z, "FRACTAL_NZ": ACL_FORMAT_FRACTAL_NZ, "HWCN": ACL_FORMAT_HWCN,
               "NCDHW": ACL_FORMAT_NCDHW, "NDC1HWC0": ACL_FORMAT_NDC1HWC0, "NDHWC": ACL_FORMAT_NDHWC}

# Bit flags OR-ed together into the profiling config built in AclOp.run;
# presumably mirror the ACL profiling data-type mask — confirm against the
# acl.prof.create_config documentation.
ACL_PROF_ACL_API = 0x0001
ACL_PROF_TASK_TIME = 0x0002
ACL_PROF_AICORE_METRICS = 0x0004
ACL_PROF_AICPU_TRACE = 0x0008


def check_ret(message, ret):
    """Raise if an ACL call did not return the success code.

    Args:
        message: Human-readable name of the ACL call, used in the error text.
        ret: Status code the ACL call returned.

    Raises:
        RuntimeError: If ``ret`` differs from ``ACL_ERROR_CODE`` (success).
            RuntimeError subclasses Exception, so existing
            ``except Exception`` handlers still catch it.
    """
    if ret != ACL_ERROR_CODE:
        raise RuntimeError("{} failed ret={}".format(message, ret))


class AclOp(object):
    """Runs a single operator on an Ascend device through the ACL (pyACL) API.

    Lifecycle (see run()): acl.init + device/context/stream setup ->
    load the operator (either a single-op .om model directory or a compiled
    ./kernel_meta/<name>.o kernel) -> build input/output tensor descriptors
    and device buffers -> execute (optionally under profiling) -> copy the
    outputs back to the host as numpy arrays -> release all resources.
    """

    def __init__(self):
        # ACL runtime handles; created in init_resource(), freed in release().
        self.context = None
        self.stream = None
        self.op_attr = None

        # Per-input bookkeeping, parallel lists in input order.
        self._inputs_desc = []
        self._inputs_device = []
        self._inputs_device_buffer = []
        self._inputs_host_buffer = []

        # Per-output bookkeeping, parallel lists in output order.
        self.output_desc = []
        self.device_outputs = []
        self.device_buffer_outputs = []
        self.host_outputs = []
        # Kernel launch metadata, filled in by get_block_dim().
        self.block_dim = None
        self.workspace_sizes = None
        self.workspace_num = 0
        self.kernel_name = None
        self.repeat_times = 1
        # Optional dict (e.g. {"block_dim": ..., "tiling_data": [...]})
        # overriding the kernel JSON; run() resets it to None each call.
        self.run_info = None
        # When True, _gen_input_tensor skips the static dtype/shape check.
        self.is_dynamic = False

    def _check_data_info(self, inputs_data, inputs_info):
        """Check every input array against its declared dtype and shape.

        Raises RuntimeError if the list lengths differ or any array's
        dtype/shape disagrees with the matching info dict.
        """
        len_data = len(inputs_data)
        len_info = len(inputs_info)
        if len_data != len_info:
            raise RuntimeError("length of inputs data and inputs info is not equal, the length of inputs data is %s,"
                               "inputs info is %s" % (len_data, len_info))

        for i in range(len_data):
            data = inputs_data[i]
            info = inputs_info[i]

            dtype = str(data.dtype)

            if dtype != info["type"] or list(data.shape) != list(info["shape"]):
                # Print the mismatching pair before raising, to ease debugging.
                print("data.dtype: ", dtype)
                print("info.dtype: ", info["type"])
                print("data.shape: ", data.shape)
                print("info.shape: ", info["shape"])
                raise RuntimeError("dtype or shape of inputs data and inputs info should be same")

    def get_block_dim(self, kernel_name):
        """Load launch metadata from ./kernel_meta/<kernel_name>.json.

        Sets block_dim (a run_info override wins when present), the
        workspace sizes/count, and remembers kernel_name for later use by
        select_kernel().
        """
        jsonfile = "./kernel_meta/" + kernel_name + ".json"
        with open(jsonfile, "r") as json_file:
            params = json.load(json_file)
            if not self.run_info:
                self.block_dim = params.get("blockDim", 1)
            else:
                self.block_dim = self.run_info.get("block_dim", 1)

            self.workspace_sizes = params.get("workspace", {}).get("size", [])
            self.workspace_num = params.get("workspace", {}).get("num", 0)
            self.kernel_name = kernel_name

    def release(self, device_id):
        """Free all buffers/descriptors/attr/stream/context created earlier,
        then reset the device and finalize ACL. Safe to call once per run."""
        print("release source stage:")
        # Pop in lockstep: the three input lists were appended in lockstep,
        # so popping each once per iteration keeps them aligned.
        while self._inputs_desc:
            ret = acl.destroy_data_buffer(self._inputs_device_buffer.pop())
            check_ret("acl.destroy_data_buffer device", ret)
            ret = acl.rt.free(self._inputs_device.pop())
            check_ret("acl.rt.free", ret)
            acl.destroy_tensor_desc(self._inputs_desc.pop())

        while self.output_desc:
            ret = acl.destroy_data_buffer(self.device_buffer_outputs.pop())
            check_ret("acl.destroy_data_buffer", ret)
            ret = acl.rt.free(self.device_outputs.pop())
            check_ret("acl.rt.free", ret)
            acl.destroy_tensor_desc(self.output_desc.pop())

        while self.host_outputs:
            ret = acl.rt.free_host(self.host_outputs.pop())
            check_ret("acl.rt.free_host", ret)

        if self.op_attr:
            acl.op.destroy_attr(self.op_attr)
            self.op_attr = None

        if self.stream:
            ret = acl.rt.destroy_stream(self.stream)
            check_ret("acl.rt.destroy_stream", ret)
            self.stream = None

        if self.context:
            ret = acl.rt.destroy_context(self.context)
            check_ret("acl.rt.destroy_context", ret)
            self.context = None

        ret = acl.rt.reset_device(device_id)
        check_ret("acl.rt.reset_device", ret)
        # NOTE(review): tiling_args is never initialized in __init__ and is
        # not read anywhere in this file — looks like dead code; confirm.
        self.tiling_args = []
        ret = acl.finalize()
        check_ret("acl.finalize", ret)
        print("release source success")

    def init_resource(self, device_id):
        """Initialize ACL and create the context, stream and op attr handle
        on the given device. Mirrors release()."""
        print("init resource stage:")
        ret = acl.init()
        check_ret("acl.init", ret)

        ret = acl.rt.set_device(device_id)
        check_ret("acl.rt.set_device", ret)

        self.context, ret = acl.rt.create_context(device_id)
        check_ret("acl.rt.create_context", ret)

        self.stream, ret = acl.rt.create_stream()
        check_ret("acl.rt.create_stream", ret)

        self.op_attr = acl.op.create_attr()
        print("init resource success")

    def _gen_input_tensor(self, inputs_data, inputs_info):
        """Create a tensor descriptor and device buffer for every input and
        copy the host data to the device.

        inputs_data: list of numpy arrays. inputs_info: list of dicts with
        at least "format" (plus "type"/"shape", checked against the data
        when not dynamic; "is_const" marks const tensors).
        """
        print("gen input data stage")
        bytes_list = []
        if not self.is_dynamic:
            self._check_data_info(inputs_data, inputs_info)

        length = len(inputs_info)
        for i in range(length):
            data_info = inputs_info[i]
            is_const = data_info.get("is_const", False)
            data = inputs_data[i]

            dtype = str(data.dtype)
            # The acl_dtype table uses "float" for 32-bit floats.
            if dtype == "float32":
                dtype = "float"
            shape = list(data.shape)
            format = data_info["format"]
            format_type = format_dict[format]

            input_desc = acl.create_tensor_desc(acl_dtype[dtype], shape, format_type)
            input_size = acl.get_tensor_desc_size(input_desc)
            input_device, ret = acl.rt.malloc(input_size, ACL_MEM_MALLOC_HUGE_ONLY)
            check_ret("acl.rt.malloc", ret)
            bytes_data = data.tobytes()
            # Keep a reference so the host bytes stay alive while their raw
            # pointer is used below.
            bytes_list.append(bytes_data)
            input_ptr = acl.util.bytes_to_ptr(bytes_data)
            ret = acl.rt.memcpy(input_device, input_size, input_ptr, input_size, ACL_MEMCPY_HOST_TO_DEVICE)

            check_ret("acl.rt.memcpy", ret)
            input_buffer = acl.create_data_buffer(input_device, input_size)

            if is_const:
                # NOTE(review): input_ptr points into bytes_data, which is
                # only kept alive by this function's local bytes_list —
                # verify that acl.set_tensor_const copies the data.
                ret = acl.set_tensor_const(input_desc, input_ptr, input_size)
                check_ret("acl.set_tensor_const", ret)

            self._inputs_device.append(input_device)
            self._inputs_device_buffer.append(input_buffer)
            self._inputs_desc.append(input_desc)
        print("gen input data success")

    def _gen_output_tensor(self, outputs_info):
        """Create a descriptor, a device buffer and a host buffer for every
        output described in outputs_info (dicts with "type"/"shape"/"format")."""
        print("gen output data stage")
        operator_output = []
        for info in outputs_info:
            dtype = info["type"]
            # Same dtype-name normalization as for inputs.
            if dtype == "float32":
                dtype = "float"

            shape = list(info["shape"])
            format = info["format"]
            format_type = format_dict[format]
            output_desc = acl.create_tensor_desc(acl_dtype[dtype], shape, format_type)
            operator_output.append(output_desc)

        for desc in operator_output:
            output_size = acl.get_tensor_desc_size(desc)
            output_device, ret = acl.rt.malloc(output_size, ACL_MEM_MALLOC_HUGE_ONLY)
            check_ret("acl.rt.malloc", ret)
            self.device_outputs.append(output_device)
            self.device_buffer_outputs.append(acl.create_data_buffer(output_device, output_size))
            # malloc_host returns (ptr, ret); the return code is not checked
            # here, unlike the device mallocs above.
            self.host_outputs.append(acl.rt.malloc_host(output_size)[0])
            self.output_desc.append(desc)
        print("gen output data success")

    def _process_attr(self, dtype, value):
        """Normalize an attribute value for the pyACL setters: bools become
        0/1 ints, list elements are coerced to int/float as declared."""
        result = value
        if dtype == "bool":
            if value:
                result = 1
            else:
                result = 0
        if dtype == "list_bool":
            result = []
            for v in value:
                if v:
                    result.append(1)
                else:
                    result.append(0)

        def int_list(inner_v):
            # Coerce one inner list of a list_list_int attribute.
            return [int(j) for j in inner_v]

        if dtype == "list_int":
            result = [int(i) for i in value]
        if dtype == "list_float":
            result = [float(i) for i in value]
        if dtype == "list_list_int":
            result = [int_list(i) for i in value]

        return result

    def _gen_attr(self, attrs):
        """Apply each attr dict ({"name", "type", "value"}) to self.op_attr
        via the setter selected from attr_funcs."""
        for attr in attrs:
            name = attr["name"]
            dtype = attr["type"]
            value = self._process_attr(dtype, attr["value"])
            attr_func = attr_funcs[dtype]
            ret = attr_func(self.op_attr, name, value)
            check_ret("acl.set_attr_%s" % dtype, ret)

    def get_om_name(self, path):
        """Return the first *.om file name found in `path`, or None."""
        names = os.listdir(path)
        for name in names:
            if name.endswith(".om"):
                return name
        return None

    def load_op(self, kernel_name, op_model_path=None):
        """Make the operator loadable: point ACL at a single-op model dir
        when given, otherwise register the compiled kernel by name."""
        if op_model_path:
            ret = acl.op.set_model_dir(op_model_path)
            check_ret("acl.op.set_model_dir", ret)
        else:
            self.load_kernel(kernel_name)

    def load_kernel(self, kernel_name):
        """Register ./kernel_meta/<kernel_name>.o with ACL and hook up
        select_kernel() as its compile callback."""
        kernel_id = kernel_name + "__kernel0"
        kernel_path = "./kernel_meta/" + kernel_name + ".o"
        np_kernel = np.fromfile(kernel_path, dtype=np.byte)
        kernel_size = np_kernel.itemsize * np_kernel.size
        # Also loads block_dim / workspace metadata from the kernel JSON.
        self.get_block_dim(kernel_name)

        ret = acl.op.register_compile_func(kernel_name, self.select_kernel)
        check_ret("acl.op.register_compile_func", ret)
        bytes_kernel = np_kernel.tobytes()
        ptr_kernel = acl.util.bytes_to_ptr(bytes_kernel)
        ret = acl.op.create_kernel(kernel_name, kernel_id, kernel_id, ptr_kernel, kernel_size, ACL_ENGINE, 0)
        check_ret("acl.op.create_kernel", ret)

    def select_kernel(self, in_num, in_desc, out_num, out_desc, op_attr, op_kernel_desc):
        """Compile callback registered in load_kernel().

        Sets the kernel launch args (tiling data from run_info, block_dim)
        and the workspace sizes on op_kernel_desc. The unused parameters are
        presumably required by the register_compile_func callback signature
        — confirm against the pyACL documentation.
        """
        tiling_data = []
        if self.run_info:
            tiling_data = self.run_info.get("tiling_data", [])

        args = np.array(tiling_data, dtype=np.int32)
        bytes_args = args.tobytes()
        # NOTE(review): args_ptr points into the local bytes_args object;
        # whether it may be used after this function returns depends on
        # acl.op.set_kernel_args copying the buffer — verify.
        args_ptr = acl.util.bytes_to_ptr(bytes_args)
        size = args.itemsize * args.size
        kernel_id = self.kernel_name + "__kernel0"
        ret = acl.op.set_kernel_args(op_kernel_desc, kernel_id, self.block_dim, args_ptr, size)
        check_ret("acl.op.set_kernel_args", ret)

        workspace_sizes = np.array(self.workspace_sizes, dtype=np.uint32)
        bytes_workspace = workspace_sizes.tobytes()
        workspace_sizes_ptr = acl.util.bytes_to_ptr(bytes_workspace)
        ret = acl.op.set_kernel_workspace_sizes(op_kernel_desc, self.workspace_num, workspace_sizes_ptr)
        check_ret("acl.op.set_kernel_workspace_sizes", ret)

    def run(self, op_type, inputs_data, inputs_info, outputs_info, attrs=None, device_id=0, op_model_path=None):
        """End-to-end single-operator execution; returns the outputs as a
        list of numpy arrays.

        When op_model_path is given, attrs are applied, the execution is
        wrapped in ACL profiling, and the result is additionally saved to
        disk via save_data (whose return value replaces the raw result).
        """
        self.init_resource(device_id)
        inputs_data = load_data(inputs_data)
        self.load_op(op_type, op_model_path)
        self.run_info = None
        self._gen_input_tensor(inputs_data, inputs_info)
        self._gen_output_tensor(outputs_info)
        if op_model_path:
            self._gen_attr(attrs)
            # Profiling output is written under op_model_path.
            ret = acl.prof.init(op_model_path)
            check_ret("acl.prof.init", ret)
            config = acl.prof.create_config([device_id], 1, 0, ACL_PROF_ACL_API | ACL_PROF_TASK_TIME |
                                            ACL_PROF_AICPU_TRACE | ACL_PROF_AICORE_METRICS)
            acl.prof.start(config)
            self._forward(op_type, op_model_path)
            acl.prof.stop(config)
            ret = acl.prof.finalize()
            check_ret("acl.prof.finalize", ret)
        else:
            self._forward(op_type, op_model_path)

        result = self._get_operator_result(outputs_info)

        self.release(device_id)
        if op_model_path:
            result = save_data(op_model_path, result, start_name="output")
        return result

    def _forward(self, op_type, op_model_path):
        """Launch the operator on self.stream and wait for it to finish.

        Without a model path the params are updated first (single-op kernel
        path); both branches then go through acl.op.execute_v2.
        """
        print('execute stage:')
        if op_model_path:
            ret = acl.op.execute_v2(op_type, self._inputs_desc, self._inputs_device_buffer, self.output_desc,
                                    self.device_buffer_outputs, self.op_attr, self.stream)
        else:
            ret = acl.op.update_params(op_type, self._inputs_desc, self.output_desc, self.op_attr)
            check_ret("acl.op.update_params", ret)
            ret = acl.op.execute_v2(op_type, self._inputs_desc, self._inputs_device_buffer, self.output_desc,
                                    self.device_buffer_outputs, self.op_attr, self.stream)

        check_ret("acl.op.execute_v2", ret)
        # execute_v2 is asynchronous on the stream; block until done.
        ret = acl.rt.synchronize_stream(self.stream)
        check_ret("acl.rt.synchronize_stream", ret)
        print("execute success")

    def _get_operator_result(self, outputs_info):
        """Copy each device output to its host buffer and view it as a numpy
        array with the dtype/shape declared in outputs_info."""
        print("get operator result stage:")
        result = []
        for index in range(len(self.output_desc)):
            factor = self.output_desc[index]
            info = outputs_info[index]
            factor_size = acl.get_tensor_desc_size(factor)
            ret = acl.rt.memcpy(self.host_outputs[index], factor_size, self.device_outputs[index], factor_size,
                                ACL_MEMCPY_DEVICE_TO_HOST)
            check_ret("acl.rt.memcpy", ret)

            data_shape = info["shape"]
            data_type = info["type"]
            np_dtype = np.dtype(data_type)
            # Logical byte count from the declared shape; may be smaller than
            # factor_size when the device format adds padding.
            data_len = functools.reduce(lambda x, y: x * y, data_shape)
            size = data_len * np_dtype.itemsize
            byte_data = acl.util.ptr_to_bytes(self.host_outputs[index], size)

            np_arr = np.frombuffer(bytearray(byte_data[:data_len * np_dtype.itemsize]), dtype=np_dtype, count=data_len)
            np_arr = np_arr.reshape(data_shape)
            result.append(np_arr)
        print("get operator result success")
        return result


# Module-level singleton used by the command-line entry point below.
acl_op = AclOp()


if __name__ == "__main__":
    # Local import: only the CLI path needs it.
    import ast

    parser = argparse.ArgumentParser(description='acl run parameters.')
    parser.add_argument('--op_type', type=str)
    parser.add_argument('--inputs_path', type=str)
    parser.add_argument('--inputs_info', type=str)
    parser.add_argument('--outputs_info', type=str)
    parser.add_argument('--attrs', type=str)
    parser.add_argument('--device_id', type=int)
    parser.add_argument('--op_model_path', type=str)
    parser.add_argument('--is_dynamic', type=int)

    args = parser.parse_args()
    op_type = args.op_type
    # SECURITY: these arguments arrive as strings from the command line.
    # The original code passed them to eval(), which executes arbitrary
    # Python. ast.literal_eval only accepts Python literals (lists, dicts,
    # strings, numbers, booleans, None), which is all these arguments are
    # expected to contain, and raises ValueError on anything else.
    inputs_path = ast.literal_eval(args.inputs_path)
    inputs_info = ast.literal_eval(args.inputs_info)
    outputs_info = ast.literal_eval(args.outputs_info)
    attrs = ast.literal_eval(args.attrs)
    device_id = int(args.device_id)
    op_model_path = args.op_model_path
    is_dynamic = int(args.is_dynamic)

    # is_dynamic is only ever used in boolean context; store it as a bool.
    acl_op.is_dynamic = bool(is_dynamic)
    acl_op.run(op_type, inputs_path, inputs_info, outputs_info, attrs, device_id, op_model_path)
